{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Problem: Implement an RNN in PyTorch\n",
    "\n",
    "### Problem Statement\n",
    "You are tasked with implementing a **Recurrent Neural Network (RNN)** in PyTorch to process sequential data. The model should contain an RNN layer for handling sequential input and a fully connected layer to output the final predictions. Your goal is to complete the RNN model by defining the necessary layers and implementing the forward pass.\n",
    "\n",
    "### Requirements\n",
    "1. **Define the RNN Model**:\n",
    "   - Add an **RNN layer** to process sequential data.\n",
    "   - Add a **fully connected layer** to map the RNN output to the final prediction.\n",
    "\n",
    "### Constraints\n",
    "- Use appropriate configurations for the RNN layer, including hidden units and input/output sizes.\n",
    "\n",
    "\n",
    "<details>\n",
    "  <summary>💡 Hint</summary>\n",
    "  Add the RNN layer (self.rnn) and fully connected layer (self.fc) in RNNModel.__init__.\n",
    "  <br>\n",
    "  Implement the forward pass to process inputs through the RNN layer and fully connected layer.\n",
    "</details>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate synthetic sequential data\n",
    "torch.manual_seed(42)\n",
    "sequence_length = 10\n",
    "num_samples = 100\n",
    "\n",
    "# Create a sine wave dataset\n",
    "X = torch.linspace(0, 4 * 3.14159, steps=num_samples).unsqueeze(1)\n",
    "y = torch.sin(X)\n",
    "\n",
    "# Prepare data for RNN\n",
    "def create_in_out_sequences(data, seq_length):\n",
    "    in_seq = []\n",
    "    out_seq = []\n",
    "    for i in range(len(data) - seq_length):\n",
    "        in_seq.append(data[i:i + seq_length])\n",
    "        out_seq.append(data[i + seq_length])\n",
    "    return torch.stack(in_seq), torch.stack(out_seq)\n",
    "\n",
    "X_seq, y_seq = create_in_out_sequences(y, sequence_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the RNN Model\n",
    "# TODO: Add RNN layer, fully connected layer, and forward implementation\n",
    "class RNNModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        ...\n",
    "\n",
    "    def forward(self, x):\n",
    "        ..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/500], Loss: 0.1162\n",
      "Epoch [2/500], Loss: 0.0031\n",
      "Epoch [3/500], Loss: 0.0095\n",
      "Epoch [4/500], Loss: 0.0002\n",
      "Epoch [5/500], Loss: 0.0000\n",
      "Epoch [6/500], Loss: 0.0268\n",
      "Epoch [7/500], Loss: 0.0165\n",
      "Epoch [8/500], Loss: 0.0763\n",
      "Epoch [9/500], Loss: 0.0163\n",
      "Epoch [10/500], Loss: 0.0008\n",
      "Epoch [11/500], Loss: 0.0009\n",
      "Epoch [12/500], Loss: 0.0009\n",
      "Epoch [13/500], Loss: 0.0016\n",
      "Epoch [14/500], Loss: 0.0000\n",
      "Epoch [15/500], Loss: 0.0008\n",
      "Epoch [16/500], Loss: 0.0003\n",
      "Epoch [17/500], Loss: 0.0004\n",
      "Epoch [18/500], Loss: 0.0028\n",
      "Epoch [19/500], Loss: 0.0062\n",
      "Epoch [20/500], Loss: 0.0022\n",
      "Epoch [21/500], Loss: 0.0046\n",
      "Epoch [22/500], Loss: 0.0066\n",
      "Epoch [23/500], Loss: 0.0072\n",
      "Epoch [24/500], Loss: 0.0083\n",
      "Epoch [25/500], Loss: 0.0089\n",
      "Epoch [26/500], Loss: 0.0073\n",
      "Epoch [27/500], Loss: 0.0053\n",
      "Epoch [28/500], Loss: 0.0048\n",
      "Epoch [29/500], Loss: 0.0043\n",
      "Epoch [30/500], Loss: 0.0040\n",
      "Epoch [31/500], Loss: 0.0034\n",
      "Epoch [32/500], Loss: 0.0027\n",
      "Epoch [33/500], Loss: 0.0016\n",
      "Epoch [34/500], Loss: 0.0004\n",
      "Epoch [35/500], Loss: 0.0002\n",
      "Epoch [36/500], Loss: 0.0002\n",
      "Epoch [37/500], Loss: 0.0002\n",
      "Epoch [38/500], Loss: 0.0067\n",
      "Epoch [39/500], Loss: 0.0180\n",
      "Epoch [40/500], Loss: 0.0259\n",
      "Epoch [41/500], Loss: 0.0022\n",
      "Epoch [42/500], Loss: 0.0032\n",
      "Epoch [43/500], Loss: 0.0003\n",
      "Epoch [44/500], Loss: 0.0006\n",
      "Epoch [45/500], Loss: 0.0007\n",
      "Epoch [46/500], Loss: 0.0001\n",
      "Epoch [47/500], Loss: 0.0000\n",
      "Epoch [48/500], Loss: 0.0008\n",
      "Epoch [49/500], Loss: 0.0057\n",
      "Epoch [50/500], Loss: 0.0056\n",
      "Epoch [51/500], Loss: 0.0005\n",
      "Epoch [52/500], Loss: 0.0030\n",
      "Epoch [53/500], Loss: 0.0075\n",
      "Epoch [54/500], Loss: 0.0046\n",
      "Epoch [55/500], Loss: 0.0035\n",
      "Epoch [56/500], Loss: 0.0001\n",
      "Epoch [57/500], Loss: 0.0051\n",
      "Epoch [58/500], Loss: 0.0029\n",
      "Epoch [59/500], Loss: 0.0026\n",
      "Epoch [60/500], Loss: 0.0005\n",
      "Epoch [61/500], Loss: 0.0000\n",
      "Epoch [62/500], Loss: 0.0001\n",
      "Epoch [63/500], Loss: 0.0001\n",
      "Epoch [64/500], Loss: 0.0001\n",
      "Epoch [65/500], Loss: 0.0001\n",
      "Epoch [66/500], Loss: 0.0000\n",
      "Epoch [67/500], Loss: 0.0000\n",
      "Epoch [68/500], Loss: 0.0002\n",
      "Epoch [69/500], Loss: 0.0007\n",
      "Epoch [70/500], Loss: 0.0011\n",
      "Epoch [71/500], Loss: 0.0012\n",
      "Epoch [72/500], Loss: 0.0010\n",
      "Epoch [73/500], Loss: 0.0005\n",
      "Epoch [74/500], Loss: 0.0002\n",
      "Epoch [75/500], Loss: 0.0001\n",
      "Epoch [76/500], Loss: 0.0000\n",
      "Epoch [77/500], Loss: 0.0000\n",
      "Epoch [78/500], Loss: 0.0000\n",
      "Epoch [79/500], Loss: 0.0001\n",
      "Epoch [80/500], Loss: 0.0002\n",
      "Epoch [81/500], Loss: 0.0005\n",
      "Epoch [82/500], Loss: 0.0010\n",
      "Epoch [83/500], Loss: 0.0018\n",
      "Epoch [84/500], Loss: 0.0007\n",
      "Epoch [85/500], Loss: 0.0012\n",
      "Epoch [86/500], Loss: 0.0049\n",
      "Epoch [87/500], Loss: 0.0003\n",
      "Epoch [88/500], Loss: 0.0001\n",
      "Epoch [89/500], Loss: 0.0025\n",
      "Epoch [90/500], Loss: 0.0013\n",
      "Epoch [91/500], Loss: 0.0001\n",
      "Epoch [92/500], Loss: 0.0001\n",
      "Epoch [93/500], Loss: 0.0010\n",
      "Epoch [94/500], Loss: 0.0000\n",
      "Epoch [95/500], Loss: 0.0020\n",
      "Epoch [96/500], Loss: 0.0003\n",
      "Epoch [97/500], Loss: 0.0131\n",
      "Epoch [98/500], Loss: 0.0003\n",
      "Epoch [99/500], Loss: 0.0017\n",
      "Epoch [100/500], Loss: 0.0002\n",
      "Epoch [101/500], Loss: 0.0000\n",
      "Epoch [102/500], Loss: 0.0000\n",
      "Epoch [103/500], Loss: 0.0004\n",
      "Epoch [104/500], Loss: 0.0017\n",
      "Epoch [105/500], Loss: 0.0047\n",
      "Epoch [106/500], Loss: 0.0095\n",
      "Epoch [107/500], Loss: 0.0045\n",
      "Epoch [108/500], Loss: 0.0045\n",
      "Epoch [109/500], Loss: 0.0018\n",
      "Epoch [110/500], Loss: 0.0005\n",
      "Epoch [111/500], Loss: 0.0001\n",
      "Epoch [112/500], Loss: 0.0001\n",
      "Epoch [113/500], Loss: 0.0002\n",
      "Epoch [114/500], Loss: 0.0004\n",
      "Epoch [115/500], Loss: 0.0004\n",
      "Epoch [116/500], Loss: 0.0002\n",
      "Epoch [117/500], Loss: 0.0000\n",
      "Epoch [118/500], Loss: 0.0000\n",
      "Epoch [119/500], Loss: 0.0001\n",
      "Epoch [120/500], Loss: 0.0030\n",
      "Epoch [121/500], Loss: 0.0084\n",
      "Epoch [122/500], Loss: 0.0051\n",
      "Epoch [123/500], Loss: 0.0042\n",
      "Epoch [124/500], Loss: 0.0004\n",
      "Epoch [125/500], Loss: 0.0003\n",
      "Epoch [126/500], Loss: 0.0192\n",
      "Epoch [127/500], Loss: 0.0000\n",
      "Epoch [128/500], Loss: 0.0061\n",
      "Epoch [129/500], Loss: 0.0000\n",
      "Epoch [130/500], Loss: 0.0004\n",
      "Epoch [131/500], Loss: 0.0019\n",
      "Epoch [132/500], Loss: 0.0020\n",
      "Epoch [133/500], Loss: 0.0012\n",
      "Epoch [134/500], Loss: 0.0003\n",
      "Epoch [135/500], Loss: 0.0000\n",
      "Epoch [136/500], Loss: 0.0003\n",
      "Epoch [137/500], Loss: 0.0038\n",
      "Epoch [138/500], Loss: 0.0009\n",
      "Epoch [139/500], Loss: 0.0000\n",
      "Epoch [140/500], Loss: 0.0004\n",
      "Epoch [141/500], Loss: 0.0065\n",
      "Epoch [142/500], Loss: 0.0117\n",
      "Epoch [143/500], Loss: 0.0004\n",
      "Epoch [144/500], Loss: 0.0092\n",
      "Epoch [145/500], Loss: 0.0001\n",
      "Epoch [146/500], Loss: 0.0009\n",
      "Epoch [147/500], Loss: 0.0000\n",
      "Epoch [148/500], Loss: 0.0001\n",
      "Epoch [149/500], Loss: 0.0007\n",
      "Epoch [150/500], Loss: 0.0021\n",
      "Epoch [151/500], Loss: 0.0021\n",
      "Epoch [152/500], Loss: 0.0061\n",
      "Epoch [153/500], Loss: 0.0019\n",
      "Epoch [154/500], Loss: 0.0045\n",
      "Epoch [155/500], Loss: 0.0026\n",
      "Epoch [156/500], Loss: 0.0000\n",
      "Epoch [157/500], Loss: 0.0000\n",
      "Epoch [158/500], Loss: 0.0003\n",
      "Epoch [159/500], Loss: 0.0003\n",
      "Epoch [160/500], Loss: 0.0004\n",
      "Epoch [161/500], Loss: 0.0005\n",
      "Epoch [162/500], Loss: 0.0006\n",
      "Epoch [163/500], Loss: 0.0003\n",
      "Epoch [164/500], Loss: 0.0026\n",
      "Epoch [165/500], Loss: 0.0002\n",
      "Epoch [166/500], Loss: 0.0001\n",
      "Epoch [167/500], Loss: 0.0008\n",
      "Epoch [168/500], Loss: 0.0022\n",
      "Epoch [169/500], Loss: 0.0000\n",
      "Epoch [170/500], Loss: 0.0018\n",
      "Epoch [171/500], Loss: 0.0055\n",
      "Epoch [172/500], Loss: 0.0003\n",
      "Epoch [173/500], Loss: 0.0000\n",
      "Epoch [174/500], Loss: 0.0050\n",
      "Epoch [175/500], Loss: 0.0025\n",
      "Epoch [176/500], Loss: 0.0003\n",
      "Epoch [177/500], Loss: 0.0008\n",
      "Epoch [178/500], Loss: 0.0017\n",
      "Epoch [179/500], Loss: 0.0000\n",
      "Epoch [180/500], Loss: 0.0019\n",
      "Epoch [181/500], Loss: 0.0001\n",
      "Epoch [182/500], Loss: 0.0000\n",
      "Epoch [183/500], Loss: 0.0001\n",
      "Epoch [184/500], Loss: 0.0003\n",
      "Epoch [185/500], Loss: 0.0003\n",
      "Epoch [186/500], Loss: 0.0001\n",
      "Epoch [187/500], Loss: 0.0017\n",
      "Epoch [188/500], Loss: 0.0021\n",
      "Epoch [189/500], Loss: 0.0009\n",
      "Epoch [190/500], Loss: 0.0001\n",
      "Epoch [191/500], Loss: 0.0008\n",
      "Epoch [192/500], Loss: 0.0000\n",
      "Epoch [193/500], Loss: 0.0004\n",
      "Epoch [194/500], Loss: 0.0001\n",
      "Epoch [195/500], Loss: 0.0002\n",
      "Epoch [196/500], Loss: 0.0001\n",
      "Epoch [197/500], Loss: 0.0004\n",
      "Epoch [198/500], Loss: 0.0022\n",
      "Epoch [199/500], Loss: 0.0009\n",
      "Epoch [200/500], Loss: 0.0002\n",
      "Epoch [201/500], Loss: 0.0505\n",
      "Epoch [202/500], Loss: 0.7340\n",
      "Epoch [203/500], Loss: 0.0394\n",
      "Epoch [204/500], Loss: 0.0000\n",
      "Epoch [205/500], Loss: 0.0001\n",
      "Epoch [206/500], Loss: 0.0161\n",
      "Epoch [207/500], Loss: 0.0003\n",
      "Epoch [208/500], Loss: 0.0233\n",
      "Epoch [209/500], Loss: 0.0770\n",
      "Epoch [210/500], Loss: 0.0126\n",
      "Epoch [211/500], Loss: 0.0002\n",
      "Epoch [212/500], Loss: 0.0004\n",
      "Epoch [213/500], Loss: 0.0000\n",
      "Epoch [214/500], Loss: 0.0004\n",
      "Epoch [215/500], Loss: 0.0018\n",
      "Epoch [216/500], Loss: 0.0010\n",
      "Epoch [217/500], Loss: 0.0000\n",
      "Epoch [218/500], Loss: 0.0000\n",
      "Epoch [219/500], Loss: 0.0052\n",
      "Epoch [220/500], Loss: 0.0120\n",
      "Epoch [221/500], Loss: 0.0023\n",
      "Epoch [222/500], Loss: 0.0000\n",
      "Epoch [223/500], Loss: 0.0005\n",
      "Epoch [224/500], Loss: 0.0008\n",
      "Epoch [225/500], Loss: 0.0007\n",
      "Epoch [226/500], Loss: 0.0005\n",
      "Epoch [227/500], Loss: 0.0003\n",
      "Epoch [228/500], Loss: 0.0000\n",
      "Epoch [229/500], Loss: 0.0008\n",
      "Epoch [230/500], Loss: 0.0008\n",
      "Epoch [231/500], Loss: 0.0000\n",
      "Epoch [232/500], Loss: 0.0000\n",
      "Epoch [233/500], Loss: 0.0001\n",
      "Epoch [234/500], Loss: 0.0000\n",
      "Epoch [235/500], Loss: 0.0000\n",
      "Epoch [236/500], Loss: 0.0000\n",
      "Epoch [237/500], Loss: 0.0000\n",
      "Epoch [238/500], Loss: 0.0001\n",
      "Epoch [239/500], Loss: 0.0009\n",
      "Epoch [240/500], Loss: 0.0049\n",
      "Epoch [241/500], Loss: 0.0030\n",
      "Epoch [242/500], Loss: 0.0009\n",
      "Epoch [243/500], Loss: 0.0011\n",
      "Epoch [244/500], Loss: 0.0000\n",
      "Epoch [245/500], Loss: 0.0014\n",
      "Epoch [246/500], Loss: 0.0039\n",
      "Epoch [247/500], Loss: 0.0014\n",
      "Epoch [248/500], Loss: 0.0019\n",
      "Epoch [249/500], Loss: 0.0000\n",
      "Epoch [250/500], Loss: 0.0028\n",
      "Epoch [251/500], Loss: 0.0016\n",
      "Epoch [252/500], Loss: 0.0044\n",
      "Epoch [253/500], Loss: 0.0016\n",
      "Epoch [254/500], Loss: 0.0003\n",
      "Epoch [255/500], Loss: 0.0067\n",
      "Epoch [256/500], Loss: 0.0000\n",
      "Epoch [257/500], Loss: 0.0005\n",
      "Epoch [258/500], Loss: 0.0015\n",
      "Epoch [259/500], Loss: 0.0008\n",
      "Epoch [260/500], Loss: 0.0022\n",
      "Epoch [261/500], Loss: 0.0152\n",
      "Epoch [262/500], Loss: 0.0029\n",
      "Epoch [263/500], Loss: 0.0001\n",
      "Epoch [264/500], Loss: 0.0000\n",
      "Epoch [265/500], Loss: 0.0001\n",
      "Epoch [266/500], Loss: 0.0001\n",
      "Epoch [267/500], Loss: 0.0000\n",
      "Epoch [268/500], Loss: 0.0015\n",
      "Epoch [269/500], Loss: 0.0071\n",
      "Epoch [270/500], Loss: 0.0047\n",
      "Epoch [271/500], Loss: 0.0011\n",
      "Epoch [272/500], Loss: 0.0017\n",
      "Epoch [273/500], Loss: 0.0014\n",
      "Epoch [274/500], Loss: 0.0007\n",
      "Epoch [275/500], Loss: 0.0002\n",
      "Epoch [276/500], Loss: 0.0002\n",
      "Epoch [277/500], Loss: 0.0002\n",
      "Epoch [278/500], Loss: 0.0006\n",
      "Epoch [279/500], Loss: 0.0010\n",
      "Epoch [280/500], Loss: 0.0003\n",
      "Epoch [281/500], Loss: 0.0000\n",
      "Epoch [282/500], Loss: 0.0002\n",
      "Epoch [283/500], Loss: 0.0017\n",
      "Epoch [284/500], Loss: 0.0054\n",
      "Epoch [285/500], Loss: 0.0003\n",
      "Epoch [286/500], Loss: 0.0035\n",
      "Epoch [287/500], Loss: 0.0013\n",
      "Epoch [288/500], Loss: 0.0210\n",
      "Epoch [289/500], Loss: 0.0003\n",
      "Epoch [290/500], Loss: 0.0045\n",
      "Epoch [291/500], Loss: 0.0007\n",
      "Epoch [292/500], Loss: 0.0001\n",
      "Epoch [293/500], Loss: 0.0000\n",
      "Epoch [294/500], Loss: 0.0001\n",
      "Epoch [295/500], Loss: 0.0000\n",
      "Epoch [296/500], Loss: 0.0006\n",
      "Epoch [297/500], Loss: 0.0002\n",
      "Epoch [298/500], Loss: 0.0000\n",
      "Epoch [299/500], Loss: 0.0016\n",
      "Epoch [300/500], Loss: 0.0065\n",
      "Epoch [301/500], Loss: 0.0015\n",
      "Epoch [302/500], Loss: 0.0071\n",
      "Epoch [303/500], Loss: 0.0079\n",
      "Epoch [304/500], Loss: 0.0051\n",
      "Epoch [305/500], Loss: 0.0080\n",
      "Epoch [306/500], Loss: 0.0010\n",
      "Epoch [307/500], Loss: 0.0011\n",
      "Epoch [308/500], Loss: 0.0010\n",
      "Epoch [309/500], Loss: 0.0001\n",
      "Epoch [310/500], Loss: 0.0013\n",
      "Epoch [311/500], Loss: 0.0056\n",
      "Epoch [312/500], Loss: 0.0021\n",
      "Epoch [313/500], Loss: 0.0004\n",
      "Epoch [314/500], Loss: 0.0006\n",
      "Epoch [315/500], Loss: 0.0000\n",
      "Epoch [316/500], Loss: 0.0006\n",
      "Epoch [317/500], Loss: 0.0047\n",
      "Epoch [318/500], Loss: 0.0024\n",
      "Epoch [319/500], Loss: 0.0009\n",
      "Epoch [320/500], Loss: 0.0029\n",
      "Epoch [321/500], Loss: 0.0000\n",
      "Epoch [322/500], Loss: 0.0005\n",
      "Epoch [323/500], Loss: 0.0000\n",
      "Epoch [324/500], Loss: 0.0000\n",
      "Epoch [325/500], Loss: 0.0001\n",
      "Epoch [326/500], Loss: 0.0002\n",
      "Epoch [327/500], Loss: 0.0001\n",
      "Epoch [328/500], Loss: 0.0006\n",
      "Epoch [329/500], Loss: 0.0003\n",
      "Epoch [330/500], Loss: 0.0060\n",
      "Epoch [331/500], Loss: 0.0001\n",
      "Epoch [332/500], Loss: 0.0004\n",
      "Epoch [333/500], Loss: 0.0000\n",
      "Epoch [334/500], Loss: 0.0008\n",
      "Epoch [335/500], Loss: 0.0087\n",
      "Epoch [336/500], Loss: 0.0066\n",
      "Epoch [337/500], Loss: 0.0015\n",
      "Epoch [338/500], Loss: 0.0094\n",
      "Epoch [339/500], Loss: 0.0028\n",
      "Epoch [340/500], Loss: 0.0197\n",
      "Epoch [341/500], Loss: 0.0002\n",
      "Epoch [342/500], Loss: 0.0015\n",
      "Epoch [343/500], Loss: 0.0003\n",
      "Epoch [344/500], Loss: 0.0000\n",
      "Epoch [345/500], Loss: 0.0000\n",
      "Epoch [346/500], Loss: 0.0001\n",
      "Epoch [347/500], Loss: 0.0000\n",
      "Epoch [348/500], Loss: 0.0000\n",
      "Epoch [349/500], Loss: 0.0000\n",
      "Epoch [350/500], Loss: 0.0001\n",
      "Epoch [351/500], Loss: 0.0000\n",
      "Epoch [352/500], Loss: 0.0002\n",
      "Epoch [353/500], Loss: 0.0005\n",
      "Epoch [354/500], Loss: 0.0019\n",
      "Epoch [355/500], Loss: 0.0014\n",
      "Epoch [356/500], Loss: 0.0009\n",
      "Epoch [357/500], Loss: 0.0011\n",
      "Epoch [358/500], Loss: 0.0039\n",
      "Epoch [359/500], Loss: 0.0029\n",
      "Epoch [360/500], Loss: 0.0019\n",
      "Epoch [361/500], Loss: 0.0059\n",
      "Epoch [362/500], Loss: 0.0063\n",
      "Epoch [363/500], Loss: 0.0001\n",
      "Epoch [364/500], Loss: 0.0007\n",
      "Epoch [365/500], Loss: 0.0027\n",
      "Epoch [366/500], Loss: 0.0022\n",
      "Epoch [367/500], Loss: 0.0001\n",
      "Epoch [368/500], Loss: 0.0057\n",
      "Epoch [369/500], Loss: 0.0052\n",
      "Epoch [370/500], Loss: 0.0066\n",
      "Epoch [371/500], Loss: 0.0001\n",
      "Epoch [372/500], Loss: 0.0027\n",
      "Epoch [373/500], Loss: 0.0047\n",
      "Epoch [374/500], Loss: 0.0000\n",
      "Epoch [375/500], Loss: 0.0002\n",
      "Epoch [376/500], Loss: 0.0014\n",
      "Epoch [377/500], Loss: 0.0035\n",
      "Epoch [378/500], Loss: 0.0013\n",
      "Epoch [379/500], Loss: 0.0008\n",
      "Epoch [380/500], Loss: 0.0027\n",
      "Epoch [381/500], Loss: 0.0000\n",
      "Epoch [382/500], Loss: 0.0007\n",
      "Epoch [383/500], Loss: 0.0019\n",
      "Epoch [384/500], Loss: 0.0000\n",
      "Epoch [385/500], Loss: 0.0035\n",
      "Epoch [386/500], Loss: 0.0003\n",
      "Epoch [387/500], Loss: 0.0003\n",
      "Epoch [388/500], Loss: 0.0052\n",
      "Epoch [389/500], Loss: 0.0043\n",
      "Epoch [390/500], Loss: 0.0022\n",
      "Epoch [391/500], Loss: 0.0000\n",
      "Epoch [392/500], Loss: 0.0120\n",
      "Epoch [393/500], Loss: 0.0002\n",
      "Epoch [394/500], Loss: 0.0035\n",
      "Epoch [395/500], Loss: 0.0002\n",
      "Epoch [396/500], Loss: 0.0013\n",
      "Epoch [397/500], Loss: 0.0052\n",
      "Epoch [398/500], Loss: 0.0042\n",
      "Epoch [399/500], Loss: 0.0022\n",
      "Epoch [400/500], Loss: 0.0049\n",
      "Epoch [401/500], Loss: 0.0052\n",
      "Epoch [402/500], Loss: 0.0066\n",
      "Epoch [403/500], Loss: 0.0009\n",
      "Epoch [404/500], Loss: 0.0001\n",
      "Epoch [405/500], Loss: 0.0001\n",
      "Epoch [406/500], Loss: 0.0003\n",
      "Epoch [407/500], Loss: 0.0012\n",
      "Epoch [408/500], Loss: 0.0030\n",
      "Epoch [409/500], Loss: 0.0028\n",
      "Epoch [410/500], Loss: 0.0016\n",
      "Epoch [411/500], Loss: 0.0001\n",
      "Epoch [412/500], Loss: 0.0007\n",
      "Epoch [413/500], Loss: 0.0000\n",
      "Epoch [414/500], Loss: 0.0002\n",
      "Epoch [415/500], Loss: 0.0029\n",
      "Epoch [416/500], Loss: 0.0053\n",
      "Epoch [417/500], Loss: 0.0023\n",
      "Epoch [418/500], Loss: 0.0023\n",
      "Epoch [419/500], Loss: 0.0000\n",
      "Epoch [420/500], Loss: 0.0017\n",
      "Epoch [421/500], Loss: 0.0000\n",
      "Epoch [422/500], Loss: 0.0018\n",
      "Epoch [423/500], Loss: 0.0008\n",
      "Epoch [424/500], Loss: 0.0024\n",
      "Epoch [425/500], Loss: 0.0027\n",
      "Epoch [426/500], Loss: 0.0014\n",
      "Epoch [427/500], Loss: 0.0031\n",
      "Epoch [428/500], Loss: 0.0005\n",
      "Epoch [429/500], Loss: 0.0060\n",
      "Epoch [430/500], Loss: 0.0006\n",
      "Epoch [431/500], Loss: 0.0031\n",
      "Epoch [432/500], Loss: 0.0011\n",
      "Epoch [433/500], Loss: 0.0000\n",
      "Epoch [434/500], Loss: 0.0006\n",
      "Epoch [435/500], Loss: 0.0015\n",
      "Epoch [436/500], Loss: 0.0001\n",
      "Epoch [437/500], Loss: 0.0005\n",
      "Epoch [438/500], Loss: 0.0000\n",
      "Epoch [439/500], Loss: 0.0003\n",
      "Epoch [440/500], Loss: 0.0002\n",
      "Epoch [441/500], Loss: 0.0005\n",
      "Epoch [442/500], Loss: 0.0000\n",
      "Epoch [443/500], Loss: 0.0013\n",
      "Epoch [444/500], Loss: 0.0035\n",
      "Epoch [445/500], Loss: 0.0003\n",
      "Epoch [446/500], Loss: 0.0001\n",
      "Epoch [447/500], Loss: 0.0005\n",
      "Epoch [448/500], Loss: 0.0001\n",
      "Epoch [449/500], Loss: 0.0025\n",
      "Epoch [450/500], Loss: 0.0020\n",
      "Epoch [451/500], Loss: 0.0002\n",
      "Epoch [452/500], Loss: 0.0013\n",
      "Epoch [453/500], Loss: 0.0000\n",
      "Epoch [454/500], Loss: 0.0000\n",
      "Epoch [455/500], Loss: 0.0003\n",
      "Epoch [456/500], Loss: 0.0000\n",
      "Epoch [457/500], Loss: 0.0012\n",
      "Epoch [458/500], Loss: 0.0001\n",
      "Epoch [459/500], Loss: 0.0006\n",
      "Epoch [460/500], Loss: 0.0004\n",
      "Epoch [461/500], Loss: 0.0000\n",
      "Epoch [462/500], Loss: 0.0021\n",
      "Epoch [463/500], Loss: 0.0003\n",
      "Epoch [464/500], Loss: 0.0002\n",
      "Epoch [465/500], Loss: 0.0015\n",
      "Epoch [466/500], Loss: 0.0003\n",
      "Epoch [467/500], Loss: 0.0018\n",
      "Epoch [468/500], Loss: 0.0020\n",
      "Epoch [469/500], Loss: 0.0002\n",
      "Epoch [470/500], Loss: 0.0060\n",
      "Epoch [471/500], Loss: 0.0005\n",
      "Epoch [472/500], Loss: 0.0000\n",
      "Epoch [473/500], Loss: 0.0050\n",
      "Epoch [474/500], Loss: 0.0009\n",
      "Epoch [475/500], Loss: 0.0001\n",
      "Epoch [476/500], Loss: 0.0003\n",
      "Epoch [477/500], Loss: 0.0002\n",
      "Epoch [478/500], Loss: 0.0001\n",
      "Epoch [479/500], Loss: 0.0005\n",
      "Epoch [480/500], Loss: 0.0000\n",
      "Epoch [481/500], Loss: 0.0004\n",
      "Epoch [482/500], Loss: 0.0012\n",
      "Epoch [483/500], Loss: 0.0001\n",
      "Epoch [484/500], Loss: 0.0031\n",
      "Epoch [485/500], Loss: 0.0006\n",
      "Epoch [486/500], Loss: 0.0003\n",
      "Epoch [487/500], Loss: 0.0022\n",
      "Epoch [488/500], Loss: 0.0019\n",
      "Epoch [489/500], Loss: 0.0013\n",
      "Epoch [490/500], Loss: 0.0000\n",
      "Epoch [491/500], Loss: 0.0001\n",
      "Epoch [492/500], Loss: 0.0022\n",
      "Epoch [493/500], Loss: 0.0008\n",
      "Epoch [494/500], Loss: 0.0012\n",
      "Epoch [495/500], Loss: 0.0057\n",
      "Epoch [496/500], Loss: 0.0047\n",
      "Epoch [497/500], Loss: 0.0005\n",
      "Epoch [498/500], Loss: 0.0001\n",
      "Epoch [499/500], Loss: 0.0000\n",
      "Epoch [500/500], Loss: 0.0000\n"
     ]
    }
   ],
   "source": [
    "# Initialize the model, loss function, and optimizer\n",
    "model = RNNModel()\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Training loop\n",
    "epochs = 500\n",
    "for epoch in range(epochs):\n",
    "    for sequences, labels in zip(X_seq, y_seq):\n",
    "        sequences = sequences.unsqueeze(0)  # Add batch dimension\n",
    "        labels = labels.unsqueeze(0)  # Add batch dimension\n",
    "\n",
    "        # Forward pass\n",
    "        outputs = model(sequences)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # Backward pass and optimization\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    print(f\"Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predictions for new sequence: [[2.7525057792663574]]\n"
     ]
    }
   ],
   "source": [
    "# Testing on new data\n",
    "X_test = torch.linspace(4 * 3.14159, 5 * 3.14159, steps=10).unsqueeze(1)\n",
    "\n",
    "# Reshape to (batch_size, sequence_length, input_size)\n",
    "X_test = X_test.unsqueeze(0)  # Add batch dimension, shape becomes (1, 10, 1)\n",
    "\n",
    "with torch.no_grad():\n",
    "    predictions = model(X_test)\n",
    "    print(f\"Predictions for new sequence: {predictions.tolist()}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
